{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 7 - 使用FederatedDataset进行联邦学习\n",
    "\n",
    "在这里，我们介绍了一种使用联合数据集的新工具。 我们创建了一个`FederatedDataset`类，该类旨在像PyTorch Dataset类一样使用，并提供给联合数据加载器`FederatedDataLoader`，它将以联合方式对其进行迭代。\n",
    "\n",
    "作者:\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "- Théo Ryffel - GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
    "\n",
    "中文版译者：\n",
    "- Hou Wei - github：[@dljgs1](https://github.com/dljgs1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使用上上节课提到的沙箱"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "import syft as sy\n",
    "sy.create_sandbox(globals(), verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "找一个数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boston_data = grid.search(\"#boston\", \"#data\")\n",
    "boston_target = grid.search(\"#boston\", \"#target\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载模型和优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features = boston_data['alice'][0].shape[1]\n",
    "n_targets = 1\n",
    "\n",
    "model = th.nn.Linear(n_features, n_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这里，我们将获取的数据转换为`FederatedDataset`中的数据。看看拥有部分数据的工作机。\n",
    "Here we cast the data fetched in a `FederatedDataset`. See the workers which hold part of the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将结果转换为BaseDatasets\n",
    "datasets = []\n",
    "for worker in boston_data.keys():\n",
    "    dataset = sy.BaseDataset(boston_data[worker][0], boston_target[worker][0])\n",
    "    datasets.append(dataset)\n",
    "\n",
    "# 建立FederatedDataset对象\n",
    "dataset = sy.FederatedDataset(datasets)\n",
    "print(dataset.workers)\n",
    "optimizers = {}\n",
    "for worker in dataset.workers:\n",
    "    optimizers[worker] = th.optim.Adam(params=model.parameters(),lr=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "放入一个`FederatedDataLoader`并进行设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = sy.FederatedDataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，我们在各个epoch进行迭代。 您会发现这与纯本地PyTorch训练相比有多相似！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, epochs + 1):\n",
    "    loss_accum = 0\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        model.send(data.location)\n",
    "        \n",
    "        optimizer = optimizers[data.location.id]\n",
    "        optimizer.zero_grad()\n",
    "        pred = model(data)\n",
    "        loss = ((pred.view(-1) - target)**2).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        model.get()\n",
    "        loss = loss.get()\n",
    "        \n",
    "        loss_accum += float(loss)\n",
    "        \n",
    "        if batch_idx % 8 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tBatch loss: {:.6f}'.format(\n",
    "                epoch, batch_idx, len(train_loader),\n",
    "                       100. * batch_idx / len(train_loader), loss.item()))            \n",
    "            \n",
    "    print('Total loss', loss_accum)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 恭喜!!! 是时候加入社区了!\n",
    "\n",
    "祝贺您完成本笔记本教程！ 如果您喜欢此方法，并希望加入保护隐私、去中心化AI和AI供应链（数据）所有权的运动，则可以通过以下方式做到这一点！\n",
    "\n",
    "### 给 PySyft 加星\n",
    "\n",
    "帮助我们的社区的最简单方法是仅通过给GitHub存储库加注星标！ 这有助于提高人们对我们正在构建的出色工具的认识。\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "\n",
    "### 加入我们的 Slack!\n",
    "\n",
    "保持最新进展的最佳方法是加入我们的社区！ 您可以通过填写以下表格来做到这一点[http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### 加入代码项目!\n",
    "\n",
    "对我们的社区做出贡献的最好方法是成为代码贡献者！ 您随时可以转到PySyft GitHub的Issue页面并过滤“projects”。这将向您显示所有概述，选择您可以加入的项目！如果您不想加入项目，但是想做一些编码，则还可以通过搜索标记为“good first issue”的GitHub问题来寻找更多的“一次性”微型项目。\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 捐赠\n",
    "\n",
    "如果您没有时间为我们的代码库做贡献，但仍想提供支持，那么您也可以成为Open Collective的支持者。所有捐款都将用于我们的网络托管和其他社区支出，例如黑客马拉松和聚会！\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
